0.0.1 Loading required packages

# ---- Session setup ----
# NOTE(review): rm(list = ls()) clears the calling workspace; it is kept only
# to preserve the original behavior of this script.
rm(list = ls())
set.seed(1)
knitr::opts_chunk$set(fig.height = 8, fig.width = 14)

# Attach a CRAN package by name, installing it first when it is missing.
ensure_package <- function(pkg) {
  if (!requireNamespace(pkg, quietly = TRUE)) {
    install.packages(pkg)
  }
  library(pkg, character.only = TRUE)
}

# CRAN dependencies, attached in the same order as before
invisible(lapply(
  c(
    "tidyverse", "readxl", "randomForest", "caret", "MLmetrics",
    "ggraph", "igraph", "cluster", "randomForestExplainer"
  ),
  ensure_package
))

# phenoGraph is installed from GitHub, which needs devtools. The package name
# passed to install.packages() must be a string.
# NOTE(review): the repository is named 'phenoClust' while the package that is
# attached is 'phenoGraph' — confirm that these match.
if (!requireNamespace("phenoGraph", quietly = TRUE)) {
  if (!requireNamespace("devtools", quietly = TRUE)) {
    install.packages("devtools")
  }
  devtools::install_github("yishaishimoni/phenoClust")
}
library(phenoGraph)

# Project-local plotting helpers
source("vimPlot.R")
source("showTree.R")

0.0.2 Load in the data

Note- I am pulling this data from a local file which cannot be shared publicly at this time

# Read the prepared analysis data from a local RDS file
dataset <- readRDS("../data/rf_data.rds")
# Preview the first rows with the subject identifier column left out
dataset %>%
  select(-IDENT_SUBID) %>%
  head()

Some Requirements:

1. The dataset must not contain any NAs/Inf/-Inf
    - Alternatively, use the `na.action` argument of the `randomForest` function to drop NAs
    - Removing NAs in advance, however, allows us to keep track of who is in the model.
2. If predicting a binary or multinomial categorical outcome, the outcome variable must be a factor.
    - Here we are not.
# Requirement 1: no NA and no Inf/-Inf anywhere in the data.
# complete.cases() only detects missing values, so the numeric columns are
# additionally checked for non-finite entries.
no_missing <- all(complete.cases(dataset))
numeric_cols <- dataset[vapply(dataset, is.numeric, logical(1))]
all_finite <- all(
  vapply(numeric_cols, function(column) all(is.finite(column)), logical(1))
)
if (no_missing && all_finite) print("Ok")
## [1] "Ok"

1 Fit a random forest model

Here we run the random forest regression algorithm, using the default hyperparameters, on the full dataset to predict RCADS anxiety total score from the neurobiological predictors and COMP/PI group.

# Baseline forest with default hyperparameters: the outcome is modelled from
# every column except the subject identifier.
forest.model <- randomForest(
  formula = rcads_anxietyTotal_T ~ . - IDENT_SUBID,
  data = dataset,
  # extra outputs requested from the fit:
  proximity = TRUE,  # the subject-by-subject proximity matrix
  localImp = TRUE,   # the local (casewise) variable importances
  importance = TRUE  # the overall variable importance metrics
)

1.1 View the results

# Plot the variable importances of the default-parameter forest.
# NOTE(review): vimPlot() is presumably defined in vimPlot.R, which is sourced
# during setup — confirm.
vimPlot(forest.model)

2 Tuning hyperparameters of the random forest model

  • Hyperparameters to tune:
  1. mtry = the number of randomly-selected variables with which to build each tree
  2. ntree = the number of trees in the ensemble/forest
  3. nodesize = the minimum size (the # of subjects) landing in each terminal node/leaf
  4. maxnodes = the maximum number of terminal nodes

Other parameters we will be interested to use given the imbalanced group-structure of our data: a. strata = the name of a factor variable giving the grouping of participants b. sampsize = the size of the bootstrap resample (or resamples if multiple groups are named above)

2.1 caret random forest tuning

The caret package will be used next to quickly tune the mtry parameter while holding all other hyperparameters constant, because out of the box caret only supports tuning mtry (but it does so quickly and efficiently).

# Resampling instructions for caret: 3-fold cross-validation, repeated and
# averaged over 10 repetitions, evaluated over an explicit grid of candidates.
trCtrl <- caret::trainControl(
  method = "repeatedcv",
  number = 3,
  repeats = 10,
  search = "grid"
)

# The "grid" is a single column of candidate mtry values. The column has to be
# called "mtry" because that is the name caret's random forest model uses for
# this hyperparameter.
# NOTE(review): the upper bound is ncol(dataset) - 3, whereas excluding only
# the identifier and the outcome would leave ncol(dataset) - 2 predictors —
# confirm the intended bound.
tunegrid <- base::expand.grid(
  mtry = seq(from = 1, to = ncol(dataset) - 3, by = 3)
)

# Fit the candidate models in parallel on a local cluster to save time.
library(doParallel)

# Leave 10 cores free (sized relative to the server this was written for), but
# always keep at least one worker so makeCluster() does not fail on machines
# with 10 or fewer cores, or when the core count cannot be detected.
available_cores <- detectCores()
if (is.na(available_cores)) available_cores <- 1L
n_cores <- max(1L, available_cores - 10L)

# Register the cluster object itself; `cores` is meant for a worker count.
cl <- makeCluster(n_cores)
registerDoParallel(cl)

# tryCatch(finally = ) shuts the workers down even when train() fails.
rf_gridsearch <- tryCatch(
  caret::train(rcads_anxietyTotal_T ~ .-IDENT_SUBID, # enter a formula
               data = dataset, # supply the data
               method = 'rf', # tell caret to train a 'rf'='random forest'
               metric = 'RMSE', maximize = FALSE, # optimize root-mean-squared-error
               trControl = trCtrl, # here is our list of specifications
               tuneGrid = tunegrid), # here is our list of mtry to test
  finally = stopCluster(cl)
)
#saveRDS(rf_gridsearch, "../output/rf_gridsearch_caret.rds")

2.1.1 Results

# Print the caret summary: resampled RMSE, R-squared and MAE for each mtry tried
rf_gridsearch
## Random Forest 
## 
## 132 samples
##  70 predictor
## 
## No pre-processing
## Resampling: Cross-Validated (3 fold, repeated 10 times) 
## Summary of sample sizes: 88, 88, 88, 88, 88, 88, ... 
## Resampling results across tuning parameters:
## 
##   mtry  RMSE      Rsquared   MAE     
##    1    7.385175  0.2013098  5.862723
##    4    7.283074  0.2071268  5.799415
##    7    7.318180  0.2189201  5.781924
##   10    7.251643  0.2272647  5.721332
##   13    7.249758  0.2309968  5.701007
##   16    7.222144  0.2278395  5.683854
##   19    7.209592  0.2238919  5.678256
##   22    7.159217  0.2257578  5.617942
##   25    7.253341  0.2260053  5.673619
##   28    7.293618  0.2274835  5.680433
##   31    7.206866  0.2258583  5.626489
##   34    7.287573  0.2272372  5.654242
##   37    7.264992  0.2354006  5.644090
##   40    7.248651  0.2297494  5.606282
##   43    7.258401  0.2346982  5.619343
##   46    7.226551  0.2336080  5.600279
##   49    7.260177  0.2306978  5.621932
##   52    7.188696  0.2297462  5.576328
##   55    7.218223  0.2310783  5.572045
##   58    7.210000  0.2255436  5.587806
##   61    7.214233  0.2255212  5.592853
##   64    7.274324  0.2219881  5.637354
##   67    7.297380  0.2316139  5.625393
## 
## RMSE was used to select the optimal model using the smallest value.
## The final value used for the model was mtry = 22.

2.2 Manual tuning using the randomForest package

# Carry forward the mtry value that the caret search selected
MTRY <- rf_gridsearch$bestTune$mtry
# Every combination of forest size (500 to 7500 trees in steps of 250) and
# minimum terminal-node size (1 through 10) that will be evaluated
params <- expand.grid(
  NTREE = seq(from = 500, to = 7500, by = 250),
  NODESIZE = 1:10
)
head(params)

Wrap randomForest() in a function that runs the algorithm, but first splits the data into a training and validation set; and subsequently outputs the test-set metrics.

# Evaluate a grid of random forest hyperparameters on a single train/test split.
#
# The data are split once at random into 70% training and 30% test rows. One
# forest is then fitted per row of `param_grid` (in parallel through
# mclapply()), and the test-set results of every fit are collected.
#
# Arguments:
#   my_data    data frame holding the outcome `rcads_anxietyTotal_T`, the
#              identifier `IDENT_SUBID`, the `GROUP` column and the predictors
#   param_grid data frame with columns `NTREE` and `NODESIZE`; one forest is
#              fitted per row
#   mtry       number of variables tried at each split; defaults to the global
#              `MTRY` (the value found by the caret search)
#   mc.cores   number of parallel workers requested; modify as needed if
#              running remotely. Values below 1 (or NA) are raised to 1.
#
# Returns a list with one element per row of `param_grid`, each being the
# `test` component of the corresponding fitted forest.
rfTune <- function(my_data, param_grid, mtry = MTRY,
                   mc.cores = detectCores() - 8) {
  # resolve the default before any worker is forked
  force(mtry)
  # mclapply() requires at least one core
  n_workers <- mc.cores
  if (is.na(n_workers) || n_workers < 1) {
    n_workers <- 1L
  }

  # 70/30 split into training and test rows
  train.inds <- sample(
    seq_len(nrow(my_data)),
    size = round(nrow(my_data) * .7),
    replace = FALSE
  )
  train_rows <- my_data[train.inds, ]
  test_rows <- my_data[-train.inds, ]
  y_train <- train_rows$rcads_anxietyTotal_T
  X_train <- train_rows %>% select(-IDENT_SUBID, -rcads_anxietyTotal_T)
  y_test <- test_rows$rcads_anxietyTotal_T
  X_test <- test_rows %>% select(-IDENT_SUBID, -rcads_anxietyTotal_T)

  # Bootstrap size: the smallest group count times the number of groups.
  # NOTE(review): randomForest documents `strata` for stratified sampling in
  # classification; confirm it has any effect for this regression forest.
  boot_size <- min(table(X_train$GROUP)) * n_distinct(X_train$GROUP)

  # Fit one forest per grid row and keep its test-set results
  gridResults <- mclapply(
    X = seq_len(nrow(param_grid)),
    FUN = function(i) {
      modelFit <- randomForest(
        x = X_train, y = y_train, xtest = X_test, ytest = y_test,
        mtry = mtry,
        strata = X_train$GROUP,
        sampsize = boot_size,
        ntree = param_grid$NTREE[i],       # the i-th ntree
        nodesize = param_grid$NODESIZE[i]  # the i-th nodesize
      )
      modelFit$test
    },
    mc.cores = n_workers
  )

  # mclapply() returns "try-error" objects for failed workers instead of
  # raising an error, so make such failures visible to the caller.
  failed <- vapply(gridResults, inherits, logical(1), what = "try-error")
  if (any(failed)) {
    warning(sum(failed), " of ", length(failed), " grid fits failed",
            call. = FALSE)
  }

  gridResults
}

2.2.1 Run the function

# Evaluate the whole parameter grid on the dataset
rfTune_results <- rfTune(dataset, params)
# Number of results returned
length(rfTune_results)
# Store the raw results on disk
saveRDS(rfTune_results, file = "../output/rfTune_res.rds")

2.2.2 Compile and view the results

# Summarise every grid fit by the test-set error of its complete forest, i.e.
# the last entry of the per-tree `mse` and `rsq` vectors.
# The i-th result is paired with the i-th row of `params`; built with vapply()
# so nothing is grown inside a loop and an empty result list yields an empty
# data frame rather than an error.
crossVal_results <- data.frame(
  ntree = params$NTREE[seq_along(rfTune_results)],
  nodesize = params$NODESIZE[seq_along(rfTune_results)],
  RMSE = vapply(
    rfTune_results,
    function(fit) sqrt(tail(fit$mse, 1)),
    numeric(1)
  ),
  R.sq = vapply(
    rfTune_results,
    function(fit) tail(fit$rsq, 1),
    numeric(1)
  )
)
head(crossVal_results)
saveRDS(crossVal_results, "../output/gridSearch_crossVal_res.rds")

2.2.3 Plot results

# One panel per fit statistic: a line and points per nodesize across ntree,
# plus a single black linear trend fitted over all nodesizes together.
crossVal_results %>%
  mutate(nodesize = factor(nodesize)) %>%
  gather(key = "fit_statistic", value = "stat_value", -nodesize, -ntree) %>%
  ggplot(aes(x = ntree, y = stat_value)) +
  geom_line(aes(group = nodesize, color = nodesize), alpha = .8) +
  geom_point(aes(group = nodesize, color = nodesize)) +
  geom_smooth(method = "glm", color = "black") +
  facet_wrap(~fit_statistic, scales = "free_y") +
  theme(text = element_text(size = 20))
## `geom_smooth()` using formula 'y ~ x'

2.2.4 Save our best parameters

# Grid row with the lowest test-set RMSE
lowest_rmse_row <- which.min(crossVal_results$RMSE)
bestParams <- crossVal_results[lowest_rmse_row, c("nodesize", "ntree")]
# Add the mtry value chosen earlier by the caret tuning
bestParams[["MTRY"]] <- rf_gridsearch$bestTune$mtry
bestParams

3 Fitting our final model

This model will be inspected for interpretation post-hoc. As such, tell randomForest to retain lots of information, such as which participants land in the same terminal node as one another, and to conduct repeated permutations of variables to derive more stable variable importance metrics (VIMs).

# Final forest, fitted on the full dataset with the tuned hyperparameters.
# May take some time to run.
final.model <- randomForest(
  rcads_anxietyTotal_T ~ . - IDENT_SUBID,
  data = dataset,
  # tuned hyperparameters
  ntree = bestParams$ntree,
  mtry = bestParams$MTRY,
  nodesize = bestParams$nodesize,
  # resampling settings based on the GROUP factor: the smallest group count
  # times the number of groups
  strata = dataset$GROUP,
  sampsize = length(table(dataset$GROUP)) * min(table(dataset$GROUP)),
  # information retained for post-hoc interpretation
  importance = TRUE,
  localImp = TRUE,
  proximity = TRUE,
  keep.forest = TRUE, # retain all ntree trees once the forest is built
  keep.inbag = TRUE,  # record who was in the in-bag sample of each tree
  nPerm = 1000        # for more stable (permutation-tested) importances
)
saveRDS(final.model, file = "../output/final_RF_model.rds")

3.1 Results

3.1.2 Interpretation using a VIM plot

# Plot the variable importances of the final model; the value returned by
# vimPlot() is kept in `vimplt`.
vimplt = vimPlot(final.model)

Notably, the error bars on the variable importances are much smaller, due to the repeated permutation procedure.

3.1.3 Visualize a tree in the ensemble

We’ll visualize the smallest tree to make the visualization more manageable.

# Draw one tree of the final forest, provided both graph packages can be
# attached. `&&` is the scalar, short-circuiting operator meant for `if`.
source("showTree.R")
if (require(ggraph) && require(igraph)) {
  # try() lets the rest of the document run if the tree cannot be drawn;
  # on success the plot is stored as `treePlot` in the global environment.
  try(assign("treePlot", showTree(final_model = final.model), pos = .GlobalEnv))
}

The above provides a sample tree structure, but note: it does not inform us of the representativeness of this extracted tree for the rest of the trees in the forest model. Each tree is built on a random bootstrap sample of the data, and its structure may be heavily influenced by that randomness. As a case in point, this tree’s first split is not at the variable GROUP, which is overall the most important variable.

4 Interpretation using randomForestExplainer

4.1 Variable importance plots

To better understand which variables are most important based on a number of criteria, such as the number of split points within a tree that occur at a given variable, in addition to that variable's influence on predictive accuracy, we use the randomForestExplainer package.

# Minimal-depth distribution of every variable in the final forest,
# plotted only when randomForestExplainer can be attached
if (require(randomForestExplainer)) {
  final.model %>%
    min_depth_distribution() %>%
    plot_min_depth_distribution()
}

The plot shows on the x-axis the number of trees in which a variable (given on the y-axis) appeared as the root of the tree, and subsequently terminated after only “minimal depth” subsequent splits (different numbers of minimal depths are denoted by the color of the horizontal bar).

4.2 Interactions between predictors

Rather than looking at how deep the entire tree grows following an interesting root node, an interaction between just two variables can be represented in a similar fashion.

A meaningful interaction between two variables occurs when a root node is followed by very few splits of another variable. Thinking of it this way, a “perfect” interaction would split subjects into 4-categories: high and high on both variables, low and low on both, low on one variable and high on the other, and vice-versa. So here we can visualize the number of splits (of only one variable) it takes for a tree to reach termination following a root node of different variables.

if (require(randomForestExplainer)) {
  # Only the interactions with GROUP are shown, since GROUP is known to be
  # the most common root variable
  final.model %>%
    min_depth_interactions("GROUP") %>%
    plot_min_depth_interactions()
}

4.3 More ways to visualize interactions

4.3.1 Predicted outcome in the domain of two predictors (and the grouping factor)

Here we explore the interaction between: 1. RightAmygdalaWITHLeftAccumbens: This variable contains the functional correlation, or co-activation, of the right amygdala with the left accumbens 2. Left.Accumbens.area: This variable contains the volume (adjusted for age and gender) of the participant’s left accumbens area

# Simulate data over the observed range of the two predictors of interest:
# a 100 x 100 grid running from each variable's minimum to its maximum.
sims <- expand.grid(
  "RightAmygdalaWITHLeftAccumbens" = seq(
    from = min(dataset$RightAmygdalaWITHLeftAccumbens),
    to = max(dataset$RightAmygdalaWITHLeftAccumbens),
    length.out = 100
  ),
  "Left.Accumbens.area" = seq(
    from = min(dataset$Left.Accumbens.area),
    to = max(dataset$Left.Accumbens.area),
    length.out = 100
  )
)

# Mean of every other variable, so that they are held constant while the
# interaction between the two variables above is explored.
means <- dataset %>%
  select(
    -GROUP, -IDENT_SUBID,
    -RightAmygdalaWITHLeftAccumbens, -Left.Accumbens.area
  ) %>%
  summarize_all(.funs = ~mean(.))

# Repeat the single row of means once per simulated grid point. Indexing with
# a repeated row number replaces appending one row at a time with rbind(),
# which copies the whole data frame on every iteration.
means <- means[rep(1L, nrow(sims)), , drop = FALSE]
rownames(means) <- NULL

# Combine the grid with the means, then duplicate the result and arbitrarily
# label one half PI and the other half COMP so the predictions can be compared
# between groups. A placeholder IDENT_SUBID column is added so the new data has
# the same columns as the data the model was built on.
newData <- cbind(sims, means)
newData <- rbind(newData, newData) %>%
  mutate(
    "IDENT_SUBID" = as.character("id"),
    GROUP = factor(rep(c("PI", "COMP"), each = nrow(.) / 2))
  )

# Predicted values for the simulated subjects
newData$rcad_predictions <- predict(final.model, newdata = newData)
saveRDS(newData, "../output/sim_data_for_plotting.rds")

# Heatmap of the predictions, one panel per group. scale_fill_viridis_c() is
# ggplot2's continuous viridis scale; scale_fill_viridis() belongs to the
# viridis package, which this script never attaches.
ggplot(newData) +
  geom_tile(aes(x = RightAmygdalaWITHLeftAccumbens,
                y = Left.Accumbens.area,
                fill = rcad_predictions)) +
  scale_fill_viridis_c() +
  facet_grid(~GROUP) +
  theme(text = element_text(size = 20))

This plot emphasizes the interaction between the two variables (and shows that it is to some extent a three-way interaction involving GROUP). The shade of the plot surface represents the RCADS anxiety score predicted by the final model at each point in the feature space of the two features plotted, as well as its covariation with the grouping factor.

5 Clustering with Random forests

The proximity matrix that results from the random forest model comprises an n by n matrix which quantifies the proportion of trees in which (every combination of) two subjects end up in the same terminal node. Note, we constrained our terminal node size to some optimal value (5) above, and so in any given tree at least five subjects will all be in the same terminal node. The proportion of times this occurs between any two given subjects reflects how similar those subjects are in terms of the model output. Thus, this matrix is a similarity matrix, and one minus the proximity (1 − proximity) may be treated as a distance/dissimilarity matrix, which is the standard input for almost any clustering paradigm.

5.1 Unsupervised Clustering

The first method for random forest clustering is to simply provide the randomForest with an X matrix containing predictors, but no response variable. The algorithm will then treat the X matrix as a multivariate normal distribution in as many dimensions as there are predictors. From this distribution, the RF generates a sample of synthetic data points (the same size as the original sample) and appends this synthetic data to the observed sample. Finally, a forest of trees is built to attempt to classify the real and synthetic data points as either real or synthetic. The resultant proximity matrix presents the distances between points with respect only to the multivariate locus of predictor variables. In this way, the approach is similar to a traditional clustering algorithm since it omits an outcome variable.

# The "iris" dataset that ships with R is used for this model because it
# contains clusters
data("iris")
# Only predictors are supplied (Species is dropped) and no response is given;
# the proximity matrix is requested for the clustering below
unsupervised.forest <- randomForest(
  x = select(iris, -Species),
  proximity = TRUE
)

Now that we have the proximity matrix, deciding how to cluster it is important. We’ll use the cluster package to conduct “partitioning around medoids” using the pam() function, as well as hierarchical agglomerative clustering using the hclust() function in the stats package to see if we can recreate the true Species variable in our clusters, or if we find something else.

# The forest's proximities are similarities (the share of trees in which two
# observations land in the same terminal node), whereas pam(diss = TRUE) and
# hclust() expect dissimilarities. Convert with 1 - proximity before clustering.
rf_dissimilarity <- as.dist(1 - unsupervised.forest$proximity)
# pam clustering:
pam.clustering <- pam(x = rf_dissimilarity, k = 3, diss = TRUE)
# hierarchical clustering:
hclustering <- hclust(rf_dissimilarity, method = "ward.D2")
# get the clusters by cutting the dendrogram
hclusters <- cutree(hclustering, k = 3)

5.1.1 Plotting

# Principal components of the four iris measurements, computed once and used
# only as plotting axes
iris_pca <- prcomp(
  select(iris, -Species),
  retx = TRUE, scale. = TRUE, center = TRUE
)

# Long-format plotting data: the iris rows appear twice, first with the pam
# cluster labels and then with the hclust cluster labels
plt_dat <- data.frame(
  "clusterType" = factor(
    rep(c("pam", "hclust"), each = nrow(iris)),
    levels = c("pam", "hclust")
  ),
  "cluster" = factor(c(pam.clustering$clustering, hclusters)),
  "Species" = factor(rep(iris$Species, times = 2)),
  "PC1" = rep(iris_pca$x[, 1], times = 2),
  "PC2" = rep(iris_pca$x[, 2], times = 2)
)

# Color shows the assigned cluster, shape shows the true species
ggplot(plt_dat) +
  geom_point(aes(x = PC1, y = PC2, color = cluster, shape = Species)) +
  facet_grid(~clusterType) +
  theme(text = element_text(size = 20))

The silhouette width of the clusters:

# Mean silhouette width of each clustering. silhouette() works on distances,
# so the forest proximities (similarities) are converted with 1 - proximity.
rf_dissimilarity <- as.dist(1 - unsupervised.forest$proximity)
c(
  "pam" = colMeans(
    silhouette(x = pam.clustering$clustering, dist = rf_dissimilarity)
  )["sil_width"],
  "hclust" = colMeans(
    silhouette(x = hclusters, dist = rf_dissimilarity)
  )["sil_width"]
)
##    pam.sil_width hclust.sil_width 
##      -0.71567918       0.03973574

5.2 Consensus matrix clustering

The next method through which one may use the random forest model to cluster participants is through consensus clustering. This is because the “proximity matrix” output of the random forest is in effect a consensus matrix: it represents the proportion of times any two subjects end up in the same cluster. Here the consensus matrix counts the proportion of trees in which any two subjects end up in the same terminal leaf, but this is directly analogous to a consensus matrix that counts the number of iterations in a repeated clustering paradigm that any two subjects end up in the same cluster. Thus, we can use community-detection algorithms on the consensus (or proximity) matrix to find smaller and more cohesive clusters.

# As indicated above, our proximity matrix IS a consensus matrix.
# Convert it to a 'dist' object and pass it along to the community-detection algorithm.
# NOTE(review): the proximity matrix holds similarities; if phenoClust() treats
# `D` as distances, 1 - proximity should be passed instead — confirm against
# the phenoClust documentation.
phenoClusters = phenoClust(D = as.dist(unsupervised.forest$proximity))

5.2.1 Plotting

# The long-format data holds every iris row twice, so keep only the "pam" half
# and attach the detected community of each observation
plt_dat %>%
  filter(clusterType == "pam") %>%
  mutate(community = factor(phenoClusters$C)) %>%
  ggplot(aes(x = PC1, y = PC2)) +
  geom_point(aes(color = community, shape = Species)) +
  theme(text = element_text(size = 20))